## Solvers

import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import random
import csv

from utils import *

def sgd(x, A, B, gamma, sigma, num_iter=1000, xtol=0.0):
    """Run plain (uncompressed) stochastic gradient descent as a baseline.

    Each step draws per-node stochastic gradients with ``stoch_gradient``,
    averages them over axis 0 and moves the iterate by ``gamma`` times that
    average.

    Args:
        x: Starting point. It is copied, so the caller's array is left alone.
        A, B: Problem data, forwarded unchanged to the ``utils`` helpers.
            ``B`` must be 2-D; its second dimension is used as ``num_dim``.
        gamma: Step size.
        sigma: Forwarded to ``stoch_gradient`` (presumably the gradient noise
            level -- confirm in ``utils``).
        num_iter: Maximum number of steps.
        xtol: Stop early once the recorded distance to the minimizer is
            strictly below this value.

    Returns:
        Tuple ``(fdists, xdists, x_iter)``. ``fdists`` and ``xdists`` are
        lists whose first entry comes from ``fdistance`` / ``xdistance`` at
        the starting point, followed by one entry per executed step:
        ``fval(x) - min_f`` and ``||x - x_star||^2 / num_dim``.
        ``x_iter`` is the last iterate.
    """
    _, num_dim = B.shape
    point = np.copy(x)

    # Entry 0 of both histories describes the starting point.
    dist_history = [xdistance(point, A, B)]
    gap_history = [fdistance(point, A, B)]

    optimal_value = min_f(A, B)
    minimizer = argmin_f_noinv(A, B)

    for _ in range(num_iter):
        node_grads = stoch_gradient(point, A, B, sigma)
        point = point - gamma * np.mean(node_grads, axis=0)

        gap = fval(point, A, B) - optimal_value
        sq_dist = np.linalg.norm(point - minimizer) ** 2 / num_dim
        dist_history.append(sq_dist)
        gap_history.append(gap)

        if sq_dist < xtol:
            break

    return gap_history, dist_history, point
def sgd_wrapper(x, A, B, compressor, delta, gamma, eta, sigma, num_iter, xtol):
    """Call ``sgd`` through the signature shared by the compressed solvers.

    ``compressor``, ``delta`` and ``eta`` are accepted so that this function
    can be invoked exactly like ``econtrol`` / ``ec`` / ``csgd``; they are
    ignored because ``sgd`` does no compression.

    Returns:
        Whatever ``sgd`` returns: ``(fdists, xdists, x_iter)``.
    """
    result = sgd(
        x,
        A,
        B,
        gamma=gamma,
        sigma=sigma,
        num_iter=num_iter,
        xtol=xtol,
    )
    return result

def econtrol(x, A, B, compressor, delta, gamma, eta, sigma, num_iter=1000, xtol=0.0):
    """Compressed solver keeping both an error term and a gradient estimate.

    Per node the solver maintains an accumulated compression error ``E`` and
    a gradient estimate ``H``. At every step the quantity
    ``eta * E + g - H`` (``g`` being the fresh stochastic gradient) is
    compressed row by row into a message ``Delta``; ``E`` absorbs what the
    message missed (``E + g - H - Delta``), ``H`` is advanced by ``Delta``,
    and the iterate moves along the node average of the new ``H``.
    On the very first step ``H`` is seeded with the stochastic gradient.

    Args:
        x: Starting point, shape ``(num_dim,)``; copied, never modified.
        A: Problem data, shape ``(num_nodes, num_dim, num_dim)``.
        B: Problem data, shape ``(num_nodes, num_dim)``.
        compressor: Callable invoked as ``compressor(vec, delta)`` on each
            row of a ``(num_nodes, num_dim)`` array.
        delta: Second argument handed to ``compressor``.
        gamma: Step size.
        eta: Weight of the accumulated error inside the compressed quantity.
        sigma: Forwarded to ``stoch_gradient`` (presumably the gradient noise
            level -- confirm in ``utils``).
        num_iter: Maximum number of steps.
        xtol: Stop early once the recorded distance to the minimizer is
            strictly below this value.

    Returns:
        Tuple ``(fdists, xdists, x_iter)``. The two lists start with the
        ``fdistance`` / ``xdistance`` value of the starting point and gain one
        entry per executed step (``fval(x) - min_f`` and
        ``||x - x_star||^2 / num_dim``). ``x_iter`` is the last iterate.
    """
    def compress_row(row):
        return compressor(row, delta)

    _, num_dim = B.shape
    point = np.copy(x)

    # Entry 0 of both histories describes the starting point.
    dist_history = [xdistance(point, A, B)]
    gap_history = [fdistance(point, A, B)]

    error_acc = np.zeros_like(B)
    grad_est = np.zeros_like(B)

    optimal_value = min_f(A, B)
    minimizer = argmin_f_noinv(A, B)

    for step in range(num_iter):
        node_grads = stoch_gradient(point, A, B, sigma)  # (num_nodes, num_dim)

        # Seed the gradient estimate with the first stochastic gradient.
        if step == 0:
            grad_est = grad_est + node_grads

        message = np.apply_along_axis(
            compress_row, 1, eta * error_acc + node_grads - grad_est
        )
        # The error update must see the estimate from *before* this step.
        error_acc = error_acc + node_grads - grad_est - message

        # The step direction and the next estimate are the same array.
        grad_est = grad_est + message
        point = point - gamma * np.mean(grad_est, axis=0)

        gap = fval(point, A, B) - optimal_value
        sq_dist = np.linalg.norm(point - minimizer) ** 2 / num_dim
        dist_history.append(sq_dist)
        gap_history.append(gap)

        if sq_dist < xtol:
            break

    return gap_history, dist_history, point

def ec(x, A, B, compressor, delta, gamma, eta, sigma, num_iter=1000, xtol=0.0):
    """Compressed solver that carries the untransmitted gradient part forward.

    Per node an accumulator ``E`` is kept. At every step ``eta * E + g``
    (``g`` being the fresh stochastic gradient) is compressed row by row into
    a message ``Delta``, the accumulator becomes ``E + g - Delta``, and the
    iterate moves by ``gamma`` times the node average of ``Delta``.

    Args:
        x: Starting point, shape ``(num_dim,)``; copied, never modified.
        A: Problem data, shape ``(num_nodes, num_dim, num_dim)``.
        B: Problem data, shape ``(num_nodes, num_dim)``.
        compressor: Callable invoked as ``compressor(vec, delta)`` on each
            row of a ``(num_nodes, num_dim)`` array.
        delta: Second argument handed to ``compressor``.
        gamma: Step size.
        eta: Weight of the accumulator inside the compressed quantity.
        sigma: Forwarded to ``stoch_gradient`` (presumably the gradient noise
            level -- confirm in ``utils``).
        num_iter: Maximum number of steps.
        xtol: Stop early once the recorded distance to the minimizer is
            strictly below this value.

    Returns:
        Tuple ``(fdists, xdists, x_iter)``. The two lists start with the
        ``fdistance`` / ``xdistance`` value of the starting point and gain one
        entry per executed step (``fval(x) - min_f`` and
        ``||x - x_star||^2 / num_dim``). ``x_iter`` is the last iterate.
    """
    def compress_row(row):
        return compressor(row, delta)

    _, num_dim = B.shape
    point = np.copy(x)

    # Entry 0 of both histories describes the starting point.
    dist_history = [xdistance(point, A, B)]
    gap_history = [fdistance(point, A, B)]

    error_acc = np.zeros_like(B)

    optimal_value = min_f(A, B)
    minimizer = argmin_f_noinv(A, B)

    for _ in range(num_iter):
        node_grads = stoch_gradient(point, A, B, sigma)  # (num_nodes, num_dim)

        message = np.apply_along_axis(
            compress_row, 1, eta * error_acc + node_grads
        )
        # Keep whatever the compressed message failed to carry.
        error_acc = error_acc + node_grads - message
        point = point - gamma * np.mean(message, axis=0)

        gap = fval(point, A, B) - optimal_value
        sq_dist = np.linalg.norm(point - minimizer) ** 2 / num_dim
        dist_history.append(sq_dist)
        gap_history.append(gap)

        if sq_dist < xtol:
            break

    return gap_history, dist_history, point

def csgd(x, A, B, compressor, delta, gamma, eta, sigma, num_iter=1000, xtol=0.0):
    """Stochastic gradient descent on directly compressed gradients.

    At every step each node's stochastic gradient is compressed on its own
    (no state is carried between steps) and the iterate moves by ``gamma``
    times the node average of the compressed gradients.

    Args:
        x: Starting point, shape ``(num_dim,)``; copied, never modified.
        A: Problem data, shape ``(num_nodes, num_dim, num_dim)``.
        B: Problem data, shape ``(num_nodes, num_dim)``.
        compressor: Callable invoked as ``compressor(vec, delta)`` on each
            row of the ``(num_nodes, num_dim)`` gradient array.
        delta: Second argument handed to ``compressor``.
        gamma: Step size.
        eta: Unused here; present so the signature matches ``econtrol``
            and ``ec``.
        sigma: Forwarded to ``stoch_gradient`` (presumably the gradient noise
            level -- confirm in ``utils``).
        num_iter: Maximum number of steps.
        xtol: Stop early once the recorded distance to the minimizer is
            strictly below this value.

    Returns:
        Tuple ``(fdists, xdists, x_iter)``. The two lists start with the
        ``fdistance`` / ``xdistance`` value of the starting point and gain one
        entry per executed step (``fval(x) - min_f`` and
        ``||x - x_star||^2 / num_dim``). ``x_iter`` is the last iterate.
    """
    def compress_row(row):
        return compressor(row, delta)

    _, num_dim = B.shape
    point = np.copy(x)

    # Entry 0 of both histories describes the starting point.
    dist_history = [xdistance(point, A, B)]
    gap_history = [fdistance(point, A, B)]

    optimal_value = min_f(A, B)
    minimizer = argmin_f_noinv(A, B)

    for _ in range(num_iter):
        node_grads = stoch_gradient(point, A, B, sigma)  # (num_nodes, num_dim)
        message = np.apply_along_axis(compress_row, 1, node_grads)
        point = point - gamma * np.mean(message, axis=0)

        gap = fval(point, A, B) - optimal_value
        sq_dist = np.linalg.norm(point - minimizer) ** 2 / num_dim
        dist_history.append(sq_dist)
        gap_history.append(gap)

        if sq_dist < xtol:
            break

    return gap_history, dist_history, point

def ef21(x, A, B, compressor, delta, gamma, sigma, num_iter=1000, xtol=0.0):
    """Compressed solver that tracks a per-node gradient estimate.

    Per node a gradient estimate ``H`` (initially zero) is kept. At every
    step the difference ``g - H`` between the fresh stochastic gradient and
    the estimate is compressed row by row into a message ``Delta``; the
    estimate becomes ``H + Delta`` and the iterate moves by ``gamma`` times
    the node average of that new estimate.

    The recorded squared distance to the minimizer is divided by ``num_dim``,
    exactly as in ``sgd``, ``econtrol``, ``ec`` and ``csgd``, so that the
    ``xdists`` histories and the meaning of ``xtol`` are comparable across
    all solvers in this file.

    Note that, unlike the other compressed solvers in this file, this
    function takes no ``eta`` argument.

    Args:
        x: Starting point, shape ``(num_dim,)``; copied, never modified.
        A: Problem data, shape ``(num_nodes, num_dim, num_dim)``.
        B: Problem data, shape ``(num_nodes, num_dim)``.
        compressor: Callable invoked as ``compressor(vec, delta)`` on each
            row of a ``(num_nodes, num_dim)`` array.
        delta: Second argument handed to ``compressor``.
        gamma: Step size.
        sigma: Forwarded to ``stoch_gradient`` (presumably the gradient noise
            level -- confirm in ``utils``).
        num_iter: Maximum number of steps.
        xtol: Stop early once the recorded (normalised) distance to the
            minimizer is strictly below this value.

    Returns:
        Tuple ``(fdists, xdists, x_iter)``. The two lists start with the
        ``fdistance`` / ``xdistance`` value of the starting point and gain one
        entry per executed step (``fval(x) - min_f`` and
        ``||x - x_star||^2 / num_dim``). ``x_iter`` is the last iterate.
    """
    def comp(vec):
        return compressor(vec, delta)

    _, num_dim = B.shape
    x_iter = np.copy(x)

    # Entry 0 of both histories describes the starting point.
    xdists = [xdistance(x_iter, A, B)]
    fdists = [fdistance(x_iter, A, B)]

    H_iter = np.zeros_like(B)
    f_star = min_f(A, B)
    x_star = argmin_f_noinv(A, B)

    for _ in range(num_iter):
        sgrad = stoch_gradient(x_iter, A, B, sigma)  # shape (num_nodes, num_dim)
        Delta = np.apply_along_axis(comp, 1, sgrad - H_iter)

        # The updated estimate is both the step direction and the next H.
        H_iter = H_iter + Delta
        x_iter = x_iter - gamma * np.mean(H_iter, axis=0)

        fdist = fval(x_iter, A, B) - f_star
        # Normalise by the dimension, matching the other solvers.
        xdist = np.linalg.norm(x_iter - x_star) ** 2 / num_dim
        xdists.append(xdist)
        fdists.append(fdist)

        if xdist < xtol:
            break

    return fdists, xdists, x_iter